{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c219841f-493c-40f9-a6c9-3700f0c525d0",
   "metadata": {},
   "source": [
    "# PEFT 库 LoRA 实战 - OpenAI Whisper-large-v2\n",
    "\n",
    "本教程使用 LoRA 在`OpenAI Whisper-large-v2`模型上实现`语音识别(ASR)`任务的微调训练。\n",
    "\n",
    "同时，我们还结合了`int8` 量化进一步降低训练过程资源开销，同时保证了精度几乎不受影响。\n",
    "\n",
    "![LoRA](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/peft/lora_diagram.png)\n",
    "\n",
    "本教程主要训练流程如下：\n",
    "- 全局参数设置\n",
    "- 数据准备\n",
    "    - 下载数据集：训练、验证和评估集\n",
    "    - 预处理数据：降采样、移除不必要字段等\n",
    "    - 数据抽样（演示需要）\n",
    "    - 应用数据集处理（`Dataset.map`）\n",
    "    - 自定义语音数据处理器\n",
    "- 模型准备\n",
    "    - 加载和处理 `int8` 精度 Whisper-Large-v2 模型\n",
    "    - LoRA Adapter 参数配置\n",
    "    - 实例化 PEFT Model：`peft_model = get_peft_model(model, config)`\n",
    "- 模型训练\n",
    "    - 训练参数配置 Seq2SeqTrainingArguments\n",
    "    - 实例化训练器 Seq2SeqTrainer\n",
    "    - 训练模型\n",
    "    - 保存模型\n",
    "- 模型推理\n",
    "    - 使用 `PeftModel` 加载 LoRA 微调后 Whisper 模型\n",
    "    - 使用 `Pipeline API` 部署微调后 Whisper 实现中文语音识别任务"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6d0a1e23-ea71-45d6-82d6-453077cf2d29",
   "metadata": {},
   "source": [
    "## 全局参数设置"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ccd00402-d821-485e-8703-fb16bcb56a9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name_or_path = \"openai/whisper-large-v2\"\n",
    "model_dir = \"models/whisper-large-v2-asr-int8\"\n",
    "\n",
    "language = \"Chinese (China)\"\n",
    "language_abbr = \"zh-CN\"\n",
    "task = \"transcribe\"\n",
    "dataset_name = \"mozilla-foundation/common_voice_11_0\"\n",
    "\n",
    "batch_size=64"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cfffa1df-e51e-4026-9817-1cebddf0061a",
   "metadata": {},
   "source": [
    "## 数据准备\n",
    "\n",
    "### 下载数据集 Common Voice\n",
    "\n",
    "Common Voice 11.0 数据集包含许多不同语言的录音，总时长达数小时。\n",
    "\n",
    "本教程以中文数据为例，展示如何使用 LoRA 在 Whisper-large-v2 上进行微调训练。\n",
    "\n",
    "首先，初始化一个DatasetDict结构，并将训练集（将训练+验证拆分为训练集）和测试集拆分好，按照中文数据集构建配置加载到内存中："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "21ff42f4-f3ec-46d3-b0c0-dd9ffbf7b50b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'client_id': '95368aab163e0387e4fd4991b4f2d8ccfbd4364bf656c860230501fd27dcedf087773e4695a6cf5de9c4f1d406d582283190d065cdfa36b0e2b060cffaca977e',\n",
       " 'path': '/root/.cache/huggingface/datasets/downloads/extracted/dcc5967c754d4c815fc005d6e297d84537028996cbcf6b34190517630cbc40b4/zh-CN_train_0/common_voice_zh-CN_33211332.mp3',\n",
       " 'audio': {'path': '/root/.cache/huggingface/datasets/downloads/extracted/dcc5967c754d4c815fc005d6e297d84537028996cbcf6b34190517630cbc40b4/zh-CN_train_0/common_voice_zh-CN_33211332.mp3',\n",
       "  'array': array([-9.09494702e-13, -2.50111043e-12, -2.04636308e-12, ...,\n",
       "          1.21667417e-05,  3.23003815e-06, -2.43064278e-07]),\n",
       "  'sampling_rate': 48000},\n",
       " 'sentence': '性喜温暖润湿气候且耐寒。',\n",
       " 'up_votes': 2,\n",
       " 'down_votes': 0,\n",
       " 'age': '',\n",
       " 'gender': '',\n",
       " 'accent': '',\n",
       " 'locale': 'zh-CN',\n",
       " 'segment': ''}"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from datasets import load_dataset, DatasetDict\n",
    "\n",
    "common_voice = DatasetDict()\n",
    "\n",
    "common_voice[\"train\"] = load_dataset(dataset_name, language_abbr, split=\"train\", trust_remote_code=True)\n",
    "common_voice[\"validation\"] = load_dataset(dataset_name, language_abbr, split=\"validation\", trust_remote_code=True)\n",
    "\n",
    "common_voice[\"train\"][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "9bed4735-d485-435f-b282-2806241e0e54",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment'],\n",
       "        num_rows: 29056\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment'],\n",
       "        num_rows: 10581\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "common_voice"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c81faa4-d8fe-4cc7-afe6-4c2615b9050f",
   "metadata": {},
   "source": [
    "## 预处理训练数据集\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "5822025f-7f8e-4141-8bfe-d8822d0da20f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoFeatureExtractor, AutoTokenizer, AutoProcessor\n",
    "\n",
    "# 从预训练模型加载特征提取器\n",
    "feature_extractor = AutoFeatureExtractor.from_pretrained(model_name_or_path)\n",
    "\n",
    "# 从预训练模型加载分词器，可以指定语言和任务以获得最适合特定需求的分词器配置\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, language=language, task=task)\n",
    "\n",
    "# 从预训练模型加载处理器，处理器通常结合了特征提取器和分词器，为特定任务提供一站式的数据预处理\n",
    "processor = AutoProcessor.from_pretrained(model_name_or_path, language=language, task=task)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f394e5cd-23b8-413e-8bde-88c3542b84fa",
   "metadata": {},
   "source": [
    "#### 移除数据集中不必要的字段"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "1690dc5a-c1f7-4556-9be3-d31ad888e52e",
   "metadata": {},
   "outputs": [],
   "source": [
    "common_voice = common_voice.remove_columns(\n",
    "    [\"accent\", \"age\", \"client_id\", \"down_votes\", \"gender\", \"locale\", \"path\", \"segment\", \"up_votes\"]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "309aff16-ea26-4474-af54-7ef244783999",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'audio': {'path': '/root/.cache/huggingface/datasets/downloads/extracted/dcc5967c754d4c815fc005d6e297d84537028996cbcf6b34190517630cbc40b4/zh-CN_train_0/common_voice_zh-CN_33211332.mp3',\n",
       "  'array': array([-9.09494702e-13, -2.50111043e-12, -2.04636308e-12, ...,\n",
       "          1.21667417e-05,  3.23003815e-06, -2.43064278e-07]),\n",
       "  'sampling_rate': 48000},\n",
       " 'sentence': '性喜温暖润湿气候且耐寒。'}"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "common_voice[\"train\"][0]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "881546ab-72e4-4bcf-852f-a8be736164b7",
   "metadata": {},
   "source": [
    "#### 降采样音频数据\n",
    "\n",
    "查看`common_voice` 数据集介绍，你会发现其音频是以48kHz的采样率进行采样的.\n",
    "\n",
    "而`Whisper`模型是在16kHZ的音频输入上预训练的，因此我们需要将音频输入降采样以匹配模型预训练时使用的采样率。\n",
    "\n",
    "通过在音频列上使用`cast_column`方法，并将`sampling_rate`设置为16kHz来对音频进行降采样。\n",
    "\n",
    "下次调用时，音频输入将实时重新取样："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "5fc451cc-e21e-473c-a702-d7d6ed098f91",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import Audio\n",
    "\n",
    "common_voice = common_voice.cast_column(\"audio\", Audio(sampling_rate=16000))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "cc3d7fcc-7c34-41c8-9857-5a6e883f6115",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'audio': {'path': '/root/.cache/huggingface/datasets/downloads/extracted/dcc5967c754d4c815fc005d6e297d84537028996cbcf6b34190517630cbc40b4/zh-CN_train_0/common_voice_zh-CN_33211332.mp3',\n",
       "  'array': array([ 6.54836185e-11, -2.91038305e-11, -5.82076609e-11, ...,\n",
       "         -5.96660539e-06,  2.71383760e-05,  1.29687833e-05]),\n",
       "  'sampling_rate': 16000},\n",
       " 'sentence': '性喜温暖润湿气候且耐寒。'}"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# sampling_rate 从 48KHZ 降为 16KHZ\n",
    "common_voice[\"train\"][0]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ee55908f-3ea3-4aee-8062-6f8d3a6573b9",
   "metadata": {},
   "source": [
    "### 整合以上数据处理为一个函数\n",
    "\n",
    "该数据预处理函数应该包括：\n",
    "- 通过加载音频列将音频输入重新采样为16kHZ。\n",
    "- 使用特征提取器从音频数组计算输入特征。\n",
    "- 将句子列标记化为输入标签。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "58f42c35-35ba-4d6b-9d15-095963cec67c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def prepare_dataset(batch):\n",
    "    audio = batch[\"audio\"]\n",
    "    batch[\"input_features\"] = feature_extractor(audio[\"array\"], sampling_rate=audio[\"sampling_rate\"]).input_features[0]\n",
    "    batch[\"labels\"] = tokenizer(batch[\"sentence\"]).input_ids\n",
    "    return batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3a04c60-09be-419c-bdc6-6d56bbf9d4d0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "a8923262-f881-476f-9806-61a3c8fb8518",
   "metadata": {},
   "source": [
    "### 数据抽样（演示需要）\n",
    "\n",
    "在 Whisper-Large-v2 上使用小规模数据进行演示训练，保持以下训练参数不变（batch_size=64）。\n",
    "\n",
    "使用 640 个样本训练，320个样本验证和评估，恰好使得1个 epoch 仅需10 steps 即可完成训练。\n",
    "\n",
    "（在 NVIDIA T4 上需要10-15分钟）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "28b14693-aa42-4f13-a537-66b5c4cb4718",
   "metadata": {},
   "outputs": [],
   "source": [
    "small_common_voice = DatasetDict()\n",
    "\n",
    "small_common_voice[\"train\"] = common_voice[\"train\"].shuffle(seed=16).select(range(640))\n",
    "small_common_voice[\"validation\"] = common_voice[\"validation\"].shuffle(seed=16).select(range(320))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "46b6a10b-81a9-428d-8130-4c2b626f3ba3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['audio', 'sentence'],\n",
       "        num_rows: 640\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: ['audio', 'sentence'],\n",
       "        num_rows: 320\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "small_common_voice"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5a40f480-6eeb-4546-891b-97ec4a7ec46d",
   "metadata": {},
   "source": [
    "### 如果全量训练，则使用完整数据代替抽样"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "392f7856-a720-40a7-af7e-40e185fc315b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 抽样数据处理\n",
    "tokenized_common_voice = small_common_voice.map(prepare_dataset)\n",
    "\n",
    "# 完整数据训练，尝试开启 `num_proc=8` 参数多进程并行处理（如阻塞无法运行，则不使用此参数）\n",
    "# tokenized_common_voice = common_voice.map(prepare_dataset, num_proc=8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "76646516-06e2-4700-92fe-4fbb87587e31",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['audio', 'sentence', 'input_features', 'labels'],\n",
       "        num_rows: 640\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: ['audio', 'sentence', 'input_features', 'labels'],\n",
       "        num_rows: 320\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenized_common_voice"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "84ec184e-d840-40b6-99af-d11392273442",
   "metadata": {},
   "source": [
    "### 自定义语音数据整理器\n",
    "\n",
    "定义了一个针对语音到文本（Seq2Seq）模型的自定义数据整理器类，特别适用于输入为语音特征、输出为文本序列的数据集。\n",
    "\n",
    "\n",
    "这个整理器（`DataCollatorSpeechSeq2SeqWithPadding`）旨在将数据点批量打包，将每个批次中的`attention_mask`填充到最大长度，以保持批处理中张量形状的一致性，并用`-100`替换填充值，以便在损失函数中被忽略。这对于神经网络的高效训练至关重要。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "4c89ffcf-c805-48c2-b7d3-ae01b687178c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "from dataclasses import dataclass\n",
    "from typing import Any, Dict, List, Union\n",
    "\n",
    "# 定义一个针对语音到文本任务的数据整理器类\n",
    "@dataclass\n",
    "class DataCollatorSpeechSeq2SeqWithPadding:\n",
    "    processor: Any  # 处理器结合了特征提取器和分词器\n",
    "\n",
    "    # 整理器函数，将特征列表处理成一个批次\n",
    "    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:\n",
    "        # 从特征列表中提取输入特征，并填充以使它们具有相同的形状\n",
    "        input_features = [{\"input_features\": feature[\"input_features\"]} for feature in features]\n",
    "        batch = self.processor.feature_extractor.pad(input_features, return_tensors=\"pt\")\n",
    "\n",
    "        # 从特征列表中提取标签特征（文本令牌），并进行填充\n",
    "        label_features = [{\"input_ids\": feature[\"labels\"]} for feature in features]\n",
    "        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors=\"pt\")\n",
    "\n",
    "        # 使用-100替换标签中的填充区域，-100通常用于在损失计算中忽略填充令牌\n",
    "        labels = labels_batch[\"input_ids\"].masked_fill(labels_batch.attention_mask.ne(1), -100)\n",
    "\n",
    "        # 如果批次中的所有序列都以句子开始令牌开头，则移除它\n",
    "        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():\n",
    "            labels = labels[:, 1:]\n",
    "\n",
    "        # 将处理过的标签添加到批次中\n",
    "        batch[\"labels\"] = labels\n",
    "\n",
    "        return batch  # 返回最终的批次，准备好进行训练或评估"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "a26a6b4d-5370-4a48-936a-84739ac0cc2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 用给定的处理器实例化数据整理器\n",
    "data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "80ecd4bc-01fd-4286-afe5-fe2639ae15a1",
   "metadata": {},
   "source": [
    "## 模型准备\n",
    "\n",
    "### 加载预训练模型（int8 精度）\n",
    "\n",
    "使用 `int8 ` 精度加载预训练模型，进一步降低显存需求。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "f9fcb121-fa5c-4c30-8bdc-9ab08ab75427",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForSpeechSeq2Seq\n",
    "\n",
    "model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name_or_path, load_in_8bit=True, device_map=\"auto\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "2cb016f1-e6e9-4fd8-9c8b-72fd23be92d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 设置模型配置中的forced_decoder_ids属性为None\n",
    "model.config.forced_decoder_ids = None  # 这通常用于指定在解码（生成文本）过程中必须使用的特定token的ID，设置为None表示没有这样的强制要求\n",
    "\n",
    "# 设置模型配置中的suppress_tokens列表为空\n",
    "model.config.suppress_tokens = []  # 这用于指定在生成过程中应被抑制（不生成）的token的列表，设置为空列表表示没有要抑制的token"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "25ba1fa0-ea15-48d9-8c16-70df9f0b60b1",
   "metadata": {},
   "source": [
    "### PEFT 微调前的模型处理\n",
    "\n",
    "在使用 `peft` 训练 int8 模型之前，需要进行一些预处理：\n",
    "- 将所有非 `int8` 精度模块转换为全精度（`fp32`）以保证稳定性\n",
    "- 为输入嵌入层添加一个 `forward_hook`，以启用输入隐藏状态的梯度计算\n",
    "- 启用梯度检查点以实现更高效的内存训练\n",
    "\n",
    "使用 `peft` 库预定义的工具函数 `prepare_model_for_int8_training`，便可自动完成以上模型处理工作。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "1ee34359-fe1b-48f1-827c-6a8ec4a53af7",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/miniconda3/lib/python3.11/site-packages/peft/utils/other.py:143: FutureWarning: prepare_model_for_int8_training is deprecated and will be removed in a future version. Use prepare_model_for_kbit_training instead.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "from peft import prepare_model_for_int8_training\n",
    "\n",
    "model = prepare_model_for_int8_training(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cb1212ae-c18c-459b-97a2-4b833c8414ae",
   "metadata": {},
   "source": [
    "### LoRA Adapter 配置\n",
    "\n",
    "在 `peft` 中使用`LoRA`非常简捷，借助 `PeftModel`抽象，我们可以快速使用低秩适配器（LoRA）到任意模型。\n",
    "\n",
    "通过使用 `peft` 中的 `get_peft_model` 工具函数来实现。\n",
    "\n",
    "#### 关于 LoRA 超参数的说明：\n",
    "```\n",
    "MatMul(B,A) * Scaling\n",
    "Scaling = LoRA_Alpha / Rank\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "cdf6bc9c-6d2c-4dbf-b09e-a89cb1041c46",
   "metadata": {},
   "outputs": [],
   "source": [
    "from peft import LoraConfig, PeftModel, LoraModel, LoraConfig, get_peft_model\n",
    "\n",
    "# 创建一个LoraConfig对象，用于设置LoRA（Low-Rank Adaptation）的配置参数\n",
    "config = LoraConfig(\n",
    "    r=4,  # LoRA的秩，影响LoRA矩阵的大小\n",
    "    lora_alpha=64,  # LoRA适应的比例因子\n",
    "    # 指定将LoRA应用到的模型模块，通常是attention和全连接层的投影。\n",
    "    target_modules=[\"q_proj\", \"v_proj\"],\n",
    "    lora_dropout=0.05,  # 在LoRA模块中使用的dropout率\n",
    "    bias=\"none\",  # 设置bias的使用方式，这里没有使用bias\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "584652a6-cf07-49d3-a4ee-66f360441fc0",
   "metadata": {},
   "source": [
    "### 使用get_peft_model函数和给定的配置来获取一个PEFT模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "7a8f9dc5-6e15-4f16-9ac5-e7492356fe88",
   "metadata": {},
   "outputs": [],
   "source": [
    "peft_model = get_peft_model(model, config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "b74c7508-e6f4-42d8-8aaf-fe83c5977c35",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable params: 1,966,080 || all params: 1,545,271,040 || trainable%: 0.12723204856023188\n"
     ]
    }
   ],
   "source": [
    "# 打印 LoRA 微调训练的模型参数\n",
    "peft_model.print_trainable_parameters()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1cc6b26a-3e54-4a46-9b36-a048b40a37d7",
   "metadata": {},
   "source": [
    "## 模型训练\n",
    "\n",
    "#### Seq2SeqTrainingArguments 训练参数\n",
    "\n",
    "**关于设置训练步数和评估步数**\n",
    "\n",
    "基于 epochs 设置：\n",
    "\n",
    "```python\n",
    "    num_train_epochs=3,  # 训练的总轮数\n",
    "    evaluation_strategy=\"epoch\",  # 设置评估策略，这里是在每个epoch结束时进行评估\n",
    "    warmup_steps=50,  # 在训练初期增加学习率的步数，有助于稳定训练\n",
    "```\n",
    "\n",
    "基于 steps 设置：\n",
    "\n",
    "```python\n",
    "    max_steps=100, # 训练总步数\n",
    "    evaluation_strategy=\"steps\", \n",
    "    eval_steps=25, # 评估步数\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "11f259c8-dbcf-4a7f-bbb5-821ab104efee",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import Seq2SeqTrainingArguments\n",
    "\n",
    "# 设置序列到序列模型训练的参数\n",
    "training_args = Seq2SeqTrainingArguments(\n",
    "    output_dir=model_dir,  # 指定模型输出和保存的目录\n",
    "    per_device_train_batch_size=batch_size,  # 每个设备上的训练批量大小\n",
    "    learning_rate=1e-3,  # 学习率\n",
    "    num_train_epochs=1,  # 训练的总轮数\n",
    "    evaluation_strategy=\"epoch\",  # 设置评估策略，这里是在每个epoch结束时进行评估\n",
    "    # warmup_steps=50,  # 在训练初期增加学习率的步数，有助于稳定训练\n",
    "    # fp16=True,  # 启用混合精度训练，可以提高训练速度，同时减少内存使用\n",
    "    per_device_eval_batch_size=batch_size,  # 每个设备上的评估批量大小\n",
    "    generation_max_length=128,  # 生成任务的最大长度\n",
    "    logging_steps=10,  # 指定日志记录的步骤，用于跟踪训练进度\n",
    "    remove_unused_columns=False,  # 是否删除不使用的列，以减少数据处理开销\n",
    "    label_names=[\"labels\"],  # 指定标签列的名称，用于训练过程中\n",
    "    # evaluation_strategy=\"steps\",\n",
    "    # eval_steps=25,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c57ee183-b16f-4313-97f6-0df6c0f5f467",
   "metadata": {},
   "source": [
    "### 实例化 Seq2SeqTrainer 训练器"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "f8a52ed7-cae0-4aba-818e-87717430d908",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Detected kernel version 4.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n"
     ]
    }
   ],
   "source": [
    "from transformers import Seq2SeqTrainer\n",
    "\n",
    "trainer = Seq2SeqTrainer(\n",
    "    args=training_args,\n",
    "    model=peft_model,\n",
    "    train_dataset=tokenized_common_voice[\"train\"],\n",
    "    eval_dataset=tokenized_common_voice[\"validation\"],\n",
    "    data_collator=data_collator,\n",
    "    tokenizer=processor.feature_extractor,\n",
    ")\n",
    "peft_model.config.use_cache = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "6973bed7-8f53-4d55-966c-f037941e5ef3",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/miniconda3/lib/python3.11/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n",
      "/root/miniconda3/lib/python3.11/site-packages/torch/utils/checkpoint.py:91: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
      "  warnings.warn(\n",
      "/root/miniconda3/lib/python3.11/site-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
      "  warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='10' max='10' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [10/10 09:31, Epoch 1/1]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>1.502400</td>\n",
       "      <td>1.081281</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=10, training_loss=1.5024151802062988, metrics={'train_runtime': 623.8671, 'train_samples_per_second': 1.026, 'train_steps_per_second': 0.016, 'total_flos': 1.36064139264e+18, 'train_loss': 1.5024151802062988, 'epoch': 1.0})"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "620992c3-64f5-48f9-8e66-fdc5f6a27427",
   "metadata": {},
   "source": [
    "### 保存 LoRA 模型(Adapter)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "53310565-7313-46a7-acf1-215970fd4f8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.save_model(model_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "412785f2-f66f-492a-b01a-06ca81d9aed7",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "PeftModel(\n",
       "  (base_model): LoraModel(\n",
       "    (model): WhisperForConditionalGeneration(\n",
       "      (model): WhisperModel(\n",
       "        (encoder): WhisperEncoder(\n",
       "          (conv1): Conv1d(80, 1280, kernel_size=(3,), stride=(1,), padding=(1,))\n",
       "          (conv2): Conv1d(1280, 1280, kernel_size=(3,), stride=(2,), padding=(1,))\n",
       "          (embed_positions): Embedding(1500, 1280)\n",
       "          (layers): ModuleList(\n",
       "            (0-31): 32 x WhisperEncoderLayer(\n",
       "              (self_attn): WhisperSdpaAttention(\n",
       "                (k_proj): Linear8bitLt(in_features=1280, out_features=1280, bias=False)\n",
       "                (v_proj): lora.Linear8bitLt(\n",
       "                  (base_layer): Linear8bitLt(in_features=1280, out_features=1280, bias=True)\n",
       "                  (lora_dropout): ModuleDict(\n",
       "                    (default): Dropout(p=0.05, inplace=False)\n",
       "                  )\n",
       "                  (lora_A): ModuleDict(\n",
       "                    (default): Linear(in_features=1280, out_features=4, bias=False)\n",
       "                  )\n",
       "                  (lora_B): ModuleDict(\n",
       "                    (default): Linear(in_features=4, out_features=1280, bias=False)\n",
       "                  )\n",
       "                  (lora_embedding_A): ParameterDict()\n",
       "                  (lora_embedding_B): ParameterDict()\n",
       "                )\n",
       "                (q_proj): lora.Linear8bitLt(\n",
       "                  (base_layer): Linear8bitLt(in_features=1280, out_features=1280, bias=True)\n",
       "                  (lora_dropout): ModuleDict(\n",
       "                    (default): Dropout(p=0.05, inplace=False)\n",
       "                  )\n",
       "                  (lora_A): ModuleDict(\n",
       "                    (default): Linear(in_features=1280, out_features=4, bias=False)\n",
       "                  )\n",
       "                  (lora_B): ModuleDict(\n",
       "                    (default): Linear(in_features=4, out_features=1280, bias=False)\n",
       "                  )\n",
       "                  (lora_embedding_A): ParameterDict()\n",
       "                  (lora_embedding_B): ParameterDict()\n",
       "                )\n",
       "                (out_proj): Linear8bitLt(in_features=1280, out_features=1280, bias=True)\n",
       "              )\n",
       "              (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
       "              (activation_fn): GELUActivation()\n",
       "              (fc1): Linear8bitLt(in_features=1280, out_features=5120, bias=True)\n",
       "              (fc2): Linear8bitLt(in_features=5120, out_features=1280, bias=True)\n",
       "              (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
       "            )\n",
       "          )\n",
       "          (layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
       "        )\n",
       "        (decoder): WhisperDecoder(\n",
       "          (embed_tokens): Embedding(51865, 1280, padding_idx=50257)\n",
       "          (embed_positions): WhisperPositionalEmbedding(448, 1280)\n",
       "          (layers): ModuleList(\n",
       "            (0-31): 32 x WhisperDecoderLayer(\n",
       "              (self_attn): WhisperSdpaAttention(\n",
       "                (k_proj): Linear8bitLt(in_features=1280, out_features=1280, bias=False)\n",
       "                (v_proj): lora.Linear8bitLt(\n",
       "                  (base_layer): Linear8bitLt(in_features=1280, out_features=1280, bias=True)\n",
       "                  (lora_dropout): ModuleDict(\n",
       "                    (default): Dropout(p=0.05, inplace=False)\n",
       "                  )\n",
       "                  (lora_A): ModuleDict(\n",
       "                    (default): Linear(in_features=1280, out_features=4, bias=False)\n",
       "                  )\n",
       "                  (lora_B): ModuleDict(\n",
       "                    (default): Linear(in_features=4, out_features=1280, bias=False)\n",
       "                  )\n",
       "                  (lora_embedding_A): ParameterDict()\n",
       "                  (lora_embedding_B): ParameterDict()\n",
       "                )\n",
       "                (q_proj): lora.Linear8bitLt(\n",
       "                  (base_layer): Linear8bitLt(in_features=1280, out_features=1280, bias=True)\n",
       "                  (lora_dropout): ModuleDict(\n",
       "                    (default): Dropout(p=0.05, inplace=False)\n",
       "                  )\n",
       "                  (lora_A): ModuleDict(\n",
       "                    (default): Linear(in_features=1280, out_features=4, bias=False)\n",
       "                  )\n",
       "                  (lora_B): ModuleDict(\n",
       "                    (default): Linear(in_features=4, out_features=1280, bias=False)\n",
       "                  )\n",
       "                  (lora_embedding_A): ParameterDict()\n",
       "                  (lora_embedding_B): ParameterDict()\n",
       "                )\n",
       "                (out_proj): Linear8bitLt(in_features=1280, out_features=1280, bias=True)\n",
       "              )\n",
       "              (activation_fn): GELUActivation()\n",
       "              (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
       "              (encoder_attn): WhisperSdpaAttention(\n",
       "                (k_proj): Linear8bitLt(in_features=1280, out_features=1280, bias=False)\n",
       "                (v_proj): lora.Linear8bitLt(\n",
       "                  (base_layer): Linear8bitLt(in_features=1280, out_features=1280, bias=True)\n",
       "                  (lora_dropout): ModuleDict(\n",
       "                    (default): Dropout(p=0.05, inplace=False)\n",
       "                  )\n",
       "                  (lora_A): ModuleDict(\n",
       "                    (default): Linear(in_features=1280, out_features=4, bias=False)\n",
       "                  )\n",
       "                  (lora_B): ModuleDict(\n",
       "                    (default): Linear(in_features=4, out_features=1280, bias=False)\n",
       "                  )\n",
       "                  (lora_embedding_A): ParameterDict()\n",
       "                  (lora_embedding_B): ParameterDict()\n",
       "                )\n",
       "                (q_proj): lora.Linear8bitLt(\n",
       "                  (base_layer): Linear8bitLt(in_features=1280, out_features=1280, bias=True)\n",
       "                  (lora_dropout): ModuleDict(\n",
       "                    (default): Dropout(p=0.05, inplace=False)\n",
       "                  )\n",
       "                  (lora_A): ModuleDict(\n",
       "                    (default): Linear(in_features=1280, out_features=4, bias=False)\n",
       "                  )\n",
       "                  (lora_B): ModuleDict(\n",
       "                    (default): Linear(in_features=4, out_features=1280, bias=False)\n",
       "                  )\n",
       "                  (lora_embedding_A): ParameterDict()\n",
       "                  (lora_embedding_B): ParameterDict()\n",
       "                )\n",
       "                (out_proj): Linear8bitLt(in_features=1280, out_features=1280, bias=True)\n",
       "              )\n",
       "              (encoder_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
       "              (fc1): Linear8bitLt(in_features=1280, out_features=5120, bias=True)\n",
       "              (fc2): Linear8bitLt(in_features=5120, out_features=1280, bias=True)\n",
       "              (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
       "            )\n",
       "          )\n",
       "          (layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
       "        )\n",
       "      )\n",
       "      (proj_out): Linear(in_features=1280, out_features=51865, bias=False)\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "peft_model.eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dcfe9611-eee5-462f-8cb8-fed86eec76e0",
   "metadata": {},
   "source": [
    "## 模型推理（可能需要重启 Notebook）\n",
    "\n",
    "**再次加载模型会额外占用显存，如果显存已经达到上限，建议重启 Notebook 后再进行以下操作**\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "bd2890f5-2eb9-493d-b43b-266fb12c6ac6",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_dir = \"models/whisper-large-v2-asr-int8\"\n",
    "\n",
    "language = \"Chinese (China)\"\n",
    "language_abbr = \"zh-CN\"\n",
    "language_decode = \"chinese\"\n",
    "task = \"transcribe\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d5ad8bc8-420b-4a98-b83f-08303693221b",
   "metadata": {},
   "source": [
    "\n",
    "### 使用 `PeftModel` 加载 LoRA 微调后 Whisper 模型\n",
    "\n",
    "使用 `PeftConfig` 加载 LoRA Adapter 配置参数，使用 `PeftModel` 加载微调后 Whisper 模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9d7f3af5-af01-4c26-80e9-976686983178",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForSpeechSeq2Seq, AutoTokenizer, AutoProcessor\n",
    "from peft import PeftConfig, PeftModel\n",
    "\n",
    "peft_config = PeftConfig.from_pretrained(model_dir)\n",
    "\n",
    "base_model = AutoModelForSpeechSeq2Seq.from_pretrained(\n",
    "    peft_config.base_model_name_or_path, load_in_8bit=True, device_map=\"auto\"\n",
    ")\n",
    "\n",
    "peft_model = PeftModel.from_pretrained(base_model, model_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e3686334-d8d1-4782-834b-187aa684fb77",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(peft_config.base_model_name_or_path, language=language, task=task)\n",
    "processor = AutoProcessor.from_pretrained(peft_config.base_model_name_or_path, language=language, task=task)\n",
    "feature_extractor = processor.feature_extractor"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c31e558c-0c7b-445c-8210-52bd04fc0dd7",
   "metadata": {},
   "source": [
    "### 使用 Pipeline API 部署微调后 Whisper 实现中文语音识别任务"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "18181692-a143-44ee-b56c-e754d308e0ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_audio = \"data/audio/test_zh.flac\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "9d494647-082c-4e48-9486-7945618ae679",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutomaticSpeechRecognitionPipeline\n",
    "\n",
    "pipeline = AutomaticSpeechRecognitionPipeline(model=peft_model, tokenizer=tokenizer, feature_extractor=feature_extractor)\n",
    "\n",
    "forced_decoder_ids = processor.get_decoder_prompt_ids(language=language_decode, task=task)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "90da1707-9054-416f-b0b6-a6203f8d3285",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/miniconda3/lib/python3.11/site-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
      "  warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "with torch.cuda.amp.autocast():\n",
    "    text = pipeline(test_audio, max_new_tokens=255)[\"text\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "89f49787-6ab4-4bc1-91b8-a1c104c9feaf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'这是一段测试用于WhisperLarge V2模型的自动语音识别测试。'"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "text"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "0285dd19-229e-4241-b680-71e25ab51dde",
   "metadata": {},
   "source": [
    "## Homework\n",
    "\n",
    "1. 使用完整的数据集训练，对比 Train Loss 和 Validation Loss 变化。训练完成后，使用测试集进行模型评估.\n",
    "2. [Optional]使用其他语种（如：德语、法语等）的数据集进行微调训练，并进行模型评估模型评估。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8801650e-6666-412a-981f-1f8933d5df55",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
